import os
import glob
import numpy as np
import SimpleITK as sitk

from utils.coordinate import XyzTuple, xyz2irc
from utils.candidate import getCandidateInfoDict
from config import DATA_PATH


class Ct:
    """
    Loading a CT scan produces a voxel array and a transformation from patient
    coordinates to array indices.

    Attributes set by ``__init__``:
        series_uid: the series identifier passed to the constructor.
        hu_a: float32 voxel array read from the series' .mhd/.raw pair; its
            values are compared against ``threshold_hu`` (presumably
            Hounsfield units, as the name suggests).
        origin_xyz, vxSize_xyz, direction_a: image origin, voxel spacing and
            3x3 direction matrix, used with ``xyz2irc`` to map patient
            coordinates to array indices.
        positiveInfo_list: the candidate tuples for this series whose
            ``isNodule_bool`` is true.
        positive_mask: boolean array shaped like ``hu_a``, True for voxels
            inside a nodule bounding box and above the density threshold.
        positive_indexes: axis-0 indices that contain at least one True voxel
            of ``positive_mask``.
    """

    def __init__(self, series_uid):
        # LUNA16 spreads series across subset* directories; take the first
        # match. glob(...)[0] raises IndexError if no .mhd file is found.
        mhd_path = glob.glob(
            os.path.join(DATA_PATH, "LUNA16/subset*/{}.mhd".format(series_uid))
        )[0]

        # Implicitly consumes the .raw file in addition to the passed-in .mhd file
        ct_mhd = sitk.ReadImage(mhd_path)
        # GetArrayFromImage returns the axes in reverse (z, y, x) order, which
        # this class treats as (index, row, column).
        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        self.series_uid = series_uid

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        # Converts the directions to an array, and reshapes the nine-element
        # array to its proper 3x3 matrix shape
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

        candidateInfo_list = getCandidateInfoDict()[self.series_uid]

        self.positiveInfo_list = [
            candidate_tup
            for candidate_tup in candidateInfo_list
            if candidate_tup.isNodule_bool
        ]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        # Axis-0 indices with at least one positive voxel.
        self.positive_indexes = (
            self.positive_mask.sum(axis=(1, 2)).nonzero()[0].tolist()
        )

    def getRawCandidate(self, center_xyz, width_irc):
        """
        Crop a candidate sample out of the larger CT voxel array using the candidate
        center's array coordinate information (Index,Row,Column).

        Args:
            center_xyz: candidate center in patient coordinates.
            width_irc: crop size in voxels along each (index, row, column) axis.

        Returns:
            (ct_chunk, pos_chunk, center_irc): the cropped voxel values, the
            matching crop of ``positive_mask``, and the center converted to
            array coordinates. Near the volume edges the crop window is shifted
            inward, so the center is not necessarily in the middle of the chunk.
        """
        center_irc = xyz2irc(
            center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a
        )

        crop = self._buildCropSlices(center_irc, width_irc)
        ct_chunk = self.hu_a[crop]
        pos_chunk = self.positive_mask[crop]

        return ct_chunk, pos_chunk, center_irc

    def _buildCropSlices(self, center_irc, width_irc):
        """Return a tuple of per-axis slices of ``width_irc`` voxels around ``center_irc``, kept inside ``hu_a``."""
        slice_list = []
        for axis, center_val in enumerate(center_irc):
            axis_size = self.hu_a.shape[axis]
            start_ndx = int(round(center_val - width_irc[axis] / 2))
            end_ndx = int(start_ndx + width_irc[axis])

            # Shift the window right if it starts before the volume.
            if start_ndx < 0:
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            # Shift the window left if it ends past the volume. max(0, ...)
            # guards against widths larger than the axis: a negative slice
            # start would count from the far end of the array.
            if end_ndx > axis_size:
                end_ndx = axis_size
                start_ndx = max(0, int(axis_size - width_irc[axis]))

            slice_list.append(slice(start_ndx, end_ndx))

        return tuple(slice_list)

    # Building the ground truth data
    # We have annotated points, but we want a per-voxel mask that indicates
    # whether any given voxel is part of a nodule. We'll have to build that
    # mask ourselves from the data we have and then do some manual checking to
    # make sure the routine that builds the mask is performing well.
    def buildAnnotationMask(self, positiveInfo_list, threshold_hu=-700):
        """
        Build a boolean per-voxel nodule mask shaped like ``hu_a``.

        For each annotated center, a bounding box is grown outward along each
        axis until it reaches voxels at or below ``threshold_hu``. The union
        of the boxes is then restricted to voxels above ``threshold_hu``.

        Args:
            positiveInfo_list: candidate tuples exposing ``center_xyz``.
            threshold_hu: density cutoff separating nodule from surrounding
                low-density (mostly air) lung tissue.

        Returns:
            Boolean array with the same shape as ``hu_a``.
        """
        boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool_)

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            self._markBoundingBox(boundingBox_a, center_irc, threshold_hu)

        # Restricts the mask to voxels above our density threshold
        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

    def _markBoundingBox(self, boundingBox_a, center_irc, threshold_hu):
        """
        Set to True, in place, the region of ``boundingBox_a`` around one annotated center.

        ``center_irc`` is truncated to integer voxel indices. A center that
        lies outside the volume is skipped, because there is no voxel to grow
        the box from.
        """
        center = tuple(int(v) for v in center_irc)
        if any(not 0 <= c < size for c, size in zip(center, self.hu_a.shape)):
            return

        box = []
        for axis, c in enumerate(center):
            radius = self._findRadius(center, axis, threshold_hu)
            # Clamp the start at 0: a negative slice start would count from
            # the far end and could yield an empty or misplaced box.
            box.append(slice(max(c - radius, 0), c + radius + 1))

        boundingBox_a[tuple(box)] = True

    def _findRadius(self, center_irc, axis, threshold_hu):
        """
        Return how far the bounding box should extend from ``center_irc`` along ``axis``.

        An algorithm for finding a bounding box around a lung nodule: if we
        assume that the nodule locations are roughly centered in the mass, we
        can trace outward from that point until we hit low-density voxels,
        indicating that we've reached normal lung tissue (which is mostly
        filled with air).

        Starting at radius 2, the voxels at +radius and -radius along ``axis``
        are probed, and the radius grows while both are above ``threshold_hu``.
        The function returns:
          - the first radius whose probe is not above the threshold, or
          - one less than the first radius whose probe would leave the volume.
        The result is therefore always >= 1.

        Both edges are bounds-checked explicitly. Relying on IndexError only
        stops at the upper edge; a negative index silently wraps around to the
        far end of the array.
        """
        axis_size = self.hu_a.shape[axis]
        probe = list(center_irc)
        radius = 2
        while True:
            upper = center_irc[axis] + radius
            if upper >= axis_size:
                return radius - 1
            probe[axis] = upper
            if not self.hu_a[tuple(probe)] > threshold_hu:
                return radius

            lower = center_irc[axis] - radius
            if lower < 0:
                return radius - 1
            probe[axis] = lower
            if not self.hu_a[tuple(probe)] > threshold_hu:
                return radius

            radius += 1
